RA-DIT RETRIEVAL-AUGMENTED DUAL INSTRUCTION TUNING

RA-DIT
LLM
RAG
Author

Ivan Jacobs

Published

December 29, 2023

Retrieval-Augmented Dual Instruction Tuning (RA-DIT)

In this post, I explain Retrieval-Augmented Dual Instruction Tuning (RA-DIT) (Lin et al., n.d.), an approach that combines language model (LM) fine-tuning with retriever fine-tuning to improve retrieval-augmented language models (RALMs), especially on tasks that require access to large external knowledge sources. It starts from a pre-trained LLM and a state-of-the-art dense retriever and fine-tunes them in two separate steps: the LM is trained to make better use of retrieved text, and the retriever is trained to return the chunks the LM finds most useful.

The authors retrieval-augment a pre-trained auto-regressive language model, specifically LLAMA, with a dual-encoder retriever initialized from DRAGON+, a state-of-the-art dual-encoder model. The key components and equations are broken down below.

  1. Language Model (LLAMA):

    • The language model used is LLAMA (Touvron et al., 2023), a family of open-source language models pre-trained on trillions of tokens.
    • LLAMA is a pre-trained auto-regressive language model; RA-DIT augments it with retrieval.
  2. Retriever Architecture:

    • The retriever adopts a dual-encoder based architecture, known for its ease of fine-tuning and efficiency at the inference stage (Lewis et al., 2020; Izacard et al., 2022b; Shi et al., 2023b).
  3. Document and Query Encoders:

    • Given a corpus \(C\) and a query \(q\), the document encoder maps each text chunk \(c \in C\) to an embedding \(E_d(c)\).
    • The query encoder maps the query \(q\) to an embedding \(E_q(q)\).
  4. Retrieval Process:

    • The top-k relevant text chunks for the query \(q\) are retrieved based on the query-document embedding similarity.

    • The similarity score \(s(q, c)\) between the query \(q\) and a document chunk \(c\) is defined as the dot product of their embeddings: \[s(q, c) = E_q(q) \cdot E_d(c)\]
      The dot product of two vectors \(A = [a_1, a_2, \ldots, a_n]\) and \(B = [b_1, b_2, \ldots, b_n]\) is given by: \(A \cdot B = a_1 b_1 + a_2 b_2 + \ldots + a_n b_n\)

      Let \(E_q(q) = [e^{q}_1, e^{q}_2, \ldots, e^{q}_n]\) be the components of the query embedding and \(E_d(c) = [e^{d}_1, e^{d}_2, \ldots, e^{d}_n]\) the components of the document embedding, both of dimension \(n\). The score then expands to: \(s(q, c) = e^{q}_1 e^{d}_1 + e^{q}_2 e^{d}_2 + \ldots + e^{q}_n e^{d}_n\)

      The dot product equals \(\|E_q(q)\| \, \|E_d(c)\| \cos\theta\), where \(\theta\) is the angle between the two embeddings. It therefore reflects both how closely the vectors point in the same direction and how large they are, and it coincides with cosine similarity only when both vectors have unit norm. In the context of learned embeddings, a larger dot product indicates that the query and the chunk are more semantically similar, and the retriever ranks chunks by this score to select the top-k.

  5. Retriever Initialization:

    • The retriever is initialized using DRAGON+, a state-of-the-art dual-encoder model.
    • DRAGON+ is trained with a contrastive learning objective and large-scale data augmentation (Lin et al., 2023).

    In summary, the retriever uses dual encoders to map documents and queries into embeddings, and it retrieves relevant text chunks based on the similarity between the query and document embeddings. The initialization of the retriever is done using a high-performing model, DRAGON+, trained with a contrastive learning objective.
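    The scoring and top-k selection described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the paper's implementation: the embeddings are hand-written toy vectors standing in for the encoder outputs, and `dot` and `retrieve_top_k` are hypothetical helper names.

```python
import heapq

def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def retrieve_top_k(query_emb, chunk_embs, k):
    """Return the k (chunk_id, score) pairs with the highest dot-product score."""
    scored = ((cid, dot(query_emb, emb)) for cid, emb in chunk_embs.items())
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])

# Toy 3-dimensional embeddings (assumed values); real ones come from the
# document encoder and the query encoder of the dual-encoder retriever.
chunk_embs = {
    "c1": [0.9, 0.1, 0.0],
    "c2": [0.2, 0.8, 0.1],
    "c3": [0.0, 0.2, 0.9],
    "c4": [2.0, 0.0, 0.0],
}
query_emb = [1.0, 0.5, 0.0]

top = retrieve_top_k(query_emb, chunk_embs, k=2)
for chunk_id, score in top:
    print(chunk_id, round(score, 3))
```

    In this toy corpus, c4 ranks first largely because of its larger norm, even though c1 points in a direction closer to the query. That is the practical difference between a raw dot product and cosine similarity.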

  6. Parallel In-Context Retrieval-Augmentation:

    • RA-DIT adopts the Parallel In-Context Retrieval-Augmentation technique, building on the work of Shi et al. (2023b). It retrieves the top-k relevant text chunks for a given language model prompt and augments the prompt with each retrieved chunk separately. The language model predictions are then computed in parallel, and the final output probability is a mixture of the probabilities from each augmented prompt, weighted by the chunk relevance score.

      The key components and equations are as follows:

      Retrieval Process:

      • For a given language model prompt \(x\), top-k relevant text chunks \(C'\) are retrieved from the corpus \(C\). Each retrieved chunk is denoted as \(c\), and \(|C'| = k\).

      Augmentation:

      • To stay within the context window size limit, each retrieved chunk \(c\) is prepended individually to the prompt \(x\).

      Parallel Computation:

      • Language model predictions from multiple augmented prompts are computed in parallel.

      Output Probability Calculation:

      The final output probability \(p_{LM}(y|x, C')\) is a mixture of the probability from each augmented prompt, weighted by the chunk relevance score.

      \(p_{LM}(y|x, C') = \sum_{c \in C'} p_{LM}(y|c \circ x) \cdot p_{R}(c|x)\) (Equation 2)

      where:

      • \(\circ\) denotes sequence concatenation.
      • \(p_{LM}(y|c \circ x)\) is the probability of language model predictions for the augmented prompt.
      • \(p_{R}(c|x) = \frac{e^{s(x,c)}}{\sum_{c' \in C'} e^{s(x,c')}}\) is the chunk relevance score, which represents the retriever scores normalized among the top-k relevant chunks.

      Explanation:

      • The equation sums the language model's predicted probabilities over the augmented prompts, where each prompt is formed by concatenating a retrieved chunk \(c\) with the original prompt \(x\).
      • The probability \(p_{R}(c|x)\) is the chunk relevance score, representing how relevant the retrieved chunk \(c\) is to the original prompt \(x\). It is calculated by exponentiating the retriever scores and normalizing them among the top-k relevant chunks.
      • The final output probability is a weighted sum of these probabilities, where each term is weighted by the relevance of the corresponding retrieved chunk.

      In summary, the Parallel In-Context Retrieval-Augmentation technique involves retrieving relevant chunks, augmenting the prompt with each chunk, computing language model predictions in parallel, and combining the probabilities in a weighted manner based on chunk relevance scores.
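Equation 2 can be made concrete with a small numeric sketch. The retriever scores and language model probabilities below are made-up numbers for a single candidate output \(y\), and `softmax` and `mixture_probability` are hypothetical helper names; a real system would obtain these values from the retriever and from one LM forward pass per augmented prompt.

```python
import math

def softmax(scores):
    """Normalize retriever scores s(x, c) into p_R(c|x) over the top-k chunks."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def mixture_probability(lm_probs, retriever_scores):
    """Equation 2: sum over chunks of p_LM(y | c ∘ x) * p_R(c | x)."""
    weights = softmax(retriever_scores)
    return sum(p * w for p, w in zip(lm_probs, weights))

# Assumed values for k = 3 retrieved chunks.
retriever_scores = [2.0, 1.0, 0.0]  # s(x, c) for each chunk
lm_probs = [0.9, 0.5, 0.1]          # p_LM(y | c ∘ x) for each augmented prompt

p = mixture_probability(lm_probs, retriever_scores)
print(round(p, 4))
```

Because the weights form a probability distribution, the mixture always lies between the smallest and largest per-chunk probability, and chunks with higher retriever scores pull it toward their own prediction.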

RETRIEVAL AUGMENTED LANGUAGE MODEL FINE-TUNING

RA-DIT incorporates in-context retrieval augmentation into the language model fine-tuning process. The key components and equations are as follows:

  1. Fine-Tuning Setup:
    • The language model is fine-tuned on selected datasets \(D_L\) with in-context retrieval augmentation.
    • Each fine-tuning sequence is separated into an instruction segment \(x\) and an output segment \(y\).
  2. Retrieval of Relevant Text Chunks:
    • For each example \((x_i, y_i) \in D_L\), the top-\(\tilde{k}\) relevant text chunks \(C_i \subset C\) are retrieved based on \(x_i\).
    • For each retrieved chunk \(c_{ij} \in C_i\), a separate fine-tuning example is created by prepending it to the instructions as a background field.
    • This results in \(\tilde{k}\) independent fine-tuning instances per original example: \(\{(c_{ij} \circ x_i, y_i) \mid j = 1, \ldots, \tilde{k}\}\).
  3. Fine-Tuning Objective:
    • The language model is fine-tuned using the next-token prediction objective.
    • The loss function \(L(D_L)\) is the negative log likelihood of the output segment \(y_i\) given the augmented input \(c_{ij} \circ x_i\), summed over all examples and their retrieved chunks: \[L(D_L) = - \sum_i \sum_j \log p_{LM}(y_i|c_{ij} \circ x_i)\] (Equation 3)
  4. Benefits of In-Context Retrieval Augmentation:
    • In-context retrieval augmentation during fine-tuning provides a twofold benefit.
      • First, it helps the language model better utilize relevant background knowledge to make predictions.

      • Second, it allows the language model to handle cases where even state-of-the-art retrievers may falter and return inaccurate results. Training the language model to make correct predictions in such cases enables it to ignore misleading retrieval content and rely on its parametric knowledge.
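Before turning to a full training loop, the instance construction and Equation 3 can be illustrated with plain arithmetic. In this sketch the "Background:" template, the helper names `build_instances` and `sequence_nll`, and the per-token probabilities are all assumptions chosen for illustration; in practice the probabilities come from the language model.

```python
import math

def build_instances(instruction, output, chunks):
    """Return one (c_ij ∘ x_i, y_i) pair per retrieved chunk."""
    # The "Background:" template is an assumed format for the background field.
    return [(f"Background: {chunk}\n\n{instruction}", output) for chunk in chunks]

def sequence_nll(token_probs):
    """-log p_LM(y | c ∘ x), given the probability of each token of y."""
    return -sum(math.log(p) for p in token_probs)

chunks = ["Paris is the capital of France.", "France is in Europe."]
instances = build_instances("What is the capital of France?", "Paris", chunks)

# Hypothetical per-token probabilities the LM assigns to y under each augmented prompt.
token_probs_per_instance = [[0.9, 0.8], [0.4, 0.5]]

# Equation 3 for a single example i: sum the NLL over its retrieved chunks j.
loss = sum(sequence_nll(probs) for probs in token_probs_per_instance)
print(len(instances), round(loss, 4))
```

The instance built from the more helpful chunk contributes a smaller loss term, while the less helpful one contributes a larger term that pushes the model to still produce the correct output.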

Here is a simplified sketch of the fine-tuning process with in-context retrieval augmentation in PyTorch. It uses GPT-2 as a small stand-in for LLAMA and a placeholder retriever:

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load a pre-trained causal language model and tokenizer.
# GPT-2 is only a small stand-in here; the paper fine-tunes LLAMA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)
def get_relevant_text_chunks(instruction, k):
    # Placeholder: a real retriever (e.g. DRAGON+) would return the top-k
    # text chunks from the corpus for this instruction.
    return [f'(retrieved chunk {j + 1} for: {instruction})' for j in range(k)]
# Define the dataset: one instance (chunk + instruction, output) per retrieved chunk
class FineTuningDataset(Dataset):
    def __init__(self, instructions, outputs, k=3, max_length=512):
        self.max_length = max_length
        self.examples = []
        for instruction, output in zip(instructions, outputs):
            for chunk in get_relevant_text_chunks(instruction, k):
                # Prepend the chunk to the instruction as a background field
                prompt = f'Background: {chunk}\n\n{instruction}\n'
                self.examples.append((prompt, output))
    def __len__(self):
        return len(self.examples)
    def __getitem__(self, index):
        prompt, output = self.examples[index]
        prompt_ids = tokenizer.encode(prompt)
        output_ids = tokenizer.encode(output) + [tokenizer.eos_token_id]
        input_ids = (prompt_ids + output_ids)[:self.max_length]
        # Label -100 is ignored by the loss, so only the output segment y is scored
        labels = ([-100] * len(prompt_ids) + output_ids)[:self.max_length]
        return {'input_ids': input_ids, 'labels': labels}
def collate(batch):
    # Right-pad every sequence in the batch to the same length
    max_len = max(len(item['input_ids']) for item in batch)
    input_ids, attention_mask, labels = [], [], []
    for item in batch:
        n_pad = max_len - len(item['input_ids'])
        input_ids.append(item['input_ids'] + [tokenizer.eos_token_id] * n_pad)
        attention_mask.append([1] * len(item['input_ids']) + [0] * n_pad)
        labels.append(item['labels'] + [-100] * n_pad)
    return {
        'input_ids': torch.tensor(input_ids),
        'attention_mask': torch.tensor(attention_mask),
        'labels': torch.tensor(labels),
    }
# Toy training data (illustrative only)
train_instructions = ['What is the capital of France?', 'Who wrote Hamlet?']
train_outputs = [' Paris', ' William Shakespeare']
data_loader = DataLoader(FineTuningDataset(train_instructions, train_outputs), batch_size=2, shuffle=True, collate_fn=collate)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_train_epochs = 3
# Fine-tune the language model with the next-token prediction objective
for epoch in range(num_train_epochs):
    model.train()
    total_loss = 0.0
    for batch in data_loader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        optimizer.zero_grad()
        # With labels passed in, the model returns the mean cross-entropy over non-masked tokens
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss / len(data_loader)}')

In this implementation, we first load a pre-trained causal language model and its tokenizer. The function get_relevant_text_chunks stands in for the retriever and returns the top-\(k\) chunks for each instruction. The custom dataset class FineTuningDataset uses it to create one fine-tuning instance per retrieved chunk, prepending the chunk to the instruction as a background field and masking the prompt tokens so that the loss covers only the output segment. We then fine-tune the language model using the next-token prediction objective and report the average training loss per epoch. Note that this is a simplified version of the original approach: the model computes a mean over output tokens rather than the sum in Equation 3, the retriever is a placeholder, and the toy data, prompt template, and hyperparameters are illustrative. The original approach fine-tunes LLAMA on a mix of instruction-tuning datasets with chunks retrieved by DRAGON+.

Explanation:

• The fine-tuning process involves extending each training example by incorporating relevant background information retrieved from the corpus.
• The fine-tuning objective is to minimize the negative log likelihood of the language model’s predictions for the output segment, given the augmented input.
• The strategy helps the language model adapt to utilize relevant background knowledge during prediction and handle cases where retrieved information might be incorrect or misleading.

In summary, fine-tuning with in-context retrieval augmentation teaches the language model both to use relevant retrieved background knowledge and to fall back on its parametric knowledge when the retrieved content is inaccurate. The empirical efficacy of this approach is demonstrated in §5.1 of the paper.

RETRIEVER FINE-TUNING

The retriever is fine-tuned to align its output with the language model using a generalized version of LSR (LM-Supervised Retrieval) training. The goal is to use the language model itself to provide supervision for retriever fine-tuning. The key components and equations are as follows:

  1. Generalized LSR Training:
    • LSR (LM-Supervised Retrieval) training is a method that utilizes the language model to provide supervision for retriever fine-tuning.
    • A generalized version of LSR training is adopted.
  2. LSR Score Calculation:
    • For a training sample \((x, y)\) in the retriever fine-tuning dataset \(D_R\), the LSR score for a retrieved chunk \(c\) is defined as: \[p_{LSR}(c|x, y) = \frac{\exp\left(p_{LM}(y|c \circ x)/\tau\right)}{\sum_{c' \in C} \exp\left(p_{LM}(y|c' \circ x)/\tau\right)} \approx \frac{\exp\left(p_{LM}(y|c \circ x)/\tau\right)}{\sum_{c' \in C'} \exp\left(p_{LM}(y|c' \circ x)/\tau\right)}\] (Equation 4) where:
      • \(\tau\) is a temperature hyperparameter.
      • \(C' \subset C\) denotes the top-k retrieved chunks for \(x\).
    • A higher LSR score indicates that \(c\) is more effective at improving the language model’s chance of predicting the correct answer.
  3. Training Objective:
    • The training objective is to minimize the Kullback-Leibler (KL) divergence between the retriever scores \(p_{R}(c|x)\) (defined in Eq. 2) and the LSR scores \(p_{LSR}(c|x, y)\): \[L(D_R) = \mathbb{E}_{(x, y) \in D_R} \text{KL}\left[p_{R}(c|x) \, \|\, p_{LSR}(c|x, y)\right]\] (Equation 5)
  4. Update Strategy:
    • In practice, only the query encoder of the retriever is updated during fine-tuning, as updating both encoders is found to hurt performance (as mentioned in §5.1).

Explanation:

• The LSR score is a softmax over the top-k retrieved chunks: the language model’s probability of the correct output given each chunk is divided by the temperature \(\tau\), exponentiated, and normalized across the chunks.
• The goal of LSR training is for the retriever to assign higher scores to chunks that can improve the language model’s likelihood of generating the correct answer.
• The training objective, expressed as the KL divergence, measures the difference between the retriever scores and the LSR scores. Minimizing this divergence encourages the retriever to align its outputs with the language model’s predictions.
• The update is performed only on the query encoder of the retriever during fine-tuning, as updating both encoders negatively impacts performance.

In summary, the passage describes a fine-tuning strategy for the retriever that incorporates LSR training, leveraging the language model for supervision. This approach aims to improve the alignment between the retriever’s output and the language model’s predictions.
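The LSR score (Equation 4, restricted to the top-k chunks) and the KL objective (Equation 5) can be sketched for a single training sample as follows. All numbers are made up, the temperature value is a hypothetical choice, and the helper names are illustrative; an actual implementation would compute these quantities with a tensor library and backpropagate the loss into the query encoder.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def lsr_scores(lm_probs, tau):
    """Equation 4 over the top-k chunks: softmax of p_LM(y | c ∘ x) / tau."""
    return softmax([p / tau for p in lm_probs])

def kl_divergence(p, q):
    """KL[p || q] for two discrete distributions over the same chunks."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Assumed values for k = 3 retrieved chunks.
retriever_scores = [2.0, 1.0, 0.0]  # s(x, c): the retriever prefers chunk 0
lm_probs = [0.1, 0.5, 0.9]          # p_LM(y | c ∘ x): the LM benefits most from chunk 2
tau = 0.1                           # hypothetical temperature

p_r = softmax(retriever_scores)     # p_R(c|x)
p_lsr = lsr_scores(lm_probs, tau)   # p_LSR(c|x, y)
loss = kl_divergence(p_r, p_lsr)    # Equation 5 for one sample
print(round(loss, 3))
```

Here the retriever and the language model disagree about which chunk is most useful, so the loss is large. Reversing the retriever scores so that they favor chunk 2 lowers the divergence, which is the direction the query encoder is pushed during fine-tuning.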

References

Lin, Xi Victoria, Xilun Chen, Mingda Chen, Weijia Shi, Maria Lomeli, Rich James, Pedro Rodriguez, et al. n.d. “RA-DIT: Retrieval-Augmented Dual Instruction Tuning.” https://doi.org/10.48550/arXiv.2310.01352.